/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.paa;

import java.util.Arrays;

import mosdi.util.ArrayUtils;

/** Abstract base class for a probabilistic arithmetic automaton (PAA).
 *
 *  The sparse transition tables returned by getTargets() and
 *  getTargetProbabilities(), and the sparse emission tables returned by
 *  getEmissions() and getEmissionProbabilities(), are computed lazily from
 *  transitionProbability() and emissionProbability() on first use and cached
 *  afterwards. If this behaviour is not desired, these four methods must be
 *  overridden.
 *
 *  The arithmetic operation is NOT cached: performOperation() is invoked
 *  every time a value update is computed.
 */
public abstract class PAA {
	/** Lazily built table of target states, returned by getTargets(). */
	private int[][] targets;
	/** Lazily built table of transition probabilities, returned by
	 *  getTargetProbabilities(). */
	private double[][] targetProbabilities;
	/** Lazily built table of possible emissions, returned by getEmissions(). */
	private int[][] emissions;
	/** Lazily built table of emission probabilities, returned by
	 *  getEmissionProbabilities(). */
	private double[][] emissionProbabilities;

	/** Returns the size of the state set Q. */
	public abstract int getStateCount();
	/** Returns the size of the value set N. */
	public abstract int getValueCount();
	/** Returns the size of the emission set E. */
	public abstract int getEmissionCount();

	/** Returns the value the automaton holds before the first step. */
	public abstract int getStartValue();
	/** Returns the state the automaton is in before the first step. */
	public abstract int getStartState();

	/** Probability of going from {@code state} to {@code targetState}. */
	public abstract double transitionProbability(int state, int targetState);

	/** Probability that {@code state} emits {@code emission}. */
	public abstract double emissionProbability(int state, int emission);

	/** Returns the new value obtained when {@code state} is entered while the
	 *  former value is {@code value} and {@code state} emits {@code emission}. */
	public abstract int performOperation(int state, int value, int emission);

	/** Returns the operation if it can be cast into a SimpleOperation, otherwise
	 *  returns null. A non-null operation is required by
	 *  stateValueDistributionViaDoubling(). */
	public SimpleOperation getOperation() { return null; }

	/** Returns an array of targets of a given state:
	 *  j in getTargets(i) iff transitionProbability(i,j)>0.0.
	 *  Targets are listed in increasing order.
	 */
	protected int[] getTargets(int state) {
		if (targets==null) {
			int stateCount = getStateCount();
			// build into a local so a failure cannot leave a half-filled cache
			int[][] table = new int[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				int[] buffer = new int[stateCount];
				int n = 0;
				for (int j=0; j<stateCount; ++j) {
					if (transitionProbability(i,j)>0.0) buffer[n++] = j;
				}
				table[i] = Arrays.copyOf(buffer, n);
			}
			targets = table;
		}
		return targets[state];
	}

	/** Returns the array of transition probabilities for leaving the
	 *  given state. The returned array is consistent with the result
	 *  of getTargets(), i.e. getTargetProbabilities(q)[i] is the probability
	 *  of going from state q to state getTargets(q)[i].
	 */
	protected double[] getTargetProbabilities(int state) {
		if (targetProbabilities==null) {
			int stateCount = getStateCount();
			double[][] table = new double[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				int[] t = getTargets(i);
				double[] probs = new double[t.length];
				for (int k=0; k<t.length; ++k) probs[k] = transitionProbability(i, t[k]);
				table[i] = probs;
			}
			targetProbabilities = table;
		}
		return targetProbabilities[state];
	}

	/** Returns an array of possible emissions of a given state, i.e.
	 *  z in getEmissions(q) iff emissionProbability(q,z)>0.0.
	 *  Emissions are listed in increasing order.
	 */
	protected int[] getEmissions(int state) {
		if (emissions==null) {
			int stateCount = getStateCount();
			int emissionCount = getEmissionCount();
			int[][] table = new int[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				// the buffer must be able to hold every emission (not every state)
				int[] buffer = new int[emissionCount];
				int n = 0;
				for (int e=0; e<emissionCount; ++e) {
					if (emissionProbability(i,e)>0.0) buffer[n++] = e;
				}
				table[i] = Arrays.copyOf(buffer, n);
			}
			emissions = table;
		}
		return emissions[state];
	}

	/** Returns the array of emission probabilities of the given state.
	 *  The returned array is consistent with the result of getEmissions(),
	 *  i.e. getEmissionProbabilities(q)[i] is the probability of the
	 *  emission e = getEmissions(q)[i].
	 */
	protected double[] getEmissionProbabilities(int state) {
		if (emissionProbabilities==null) {
			int stateCount = getStateCount();
			double[][] table = new double[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				int[] e = getEmissions(i);
				double[] probs = new double[e.length];
				for (int k=0; k<e.length; ++k) probs[k] = emissionProbability(i, e[k]);
				table[i] = probs;
			}
			emissionProbabilities = table;
		}
		return emissionProbabilities[state];
	}

	/** Returns the joint state/value distribution before the first step:
	 *  all probability mass on (getStartState(), getStartValue()). */
	public double[][] stateValueStartDistribution() {
		double[][] result = new double[getStateCount()][getValueCount()];
		result[getStartState()][getStartValue()] = 1.0;
		return result;
	}

	/** Performs one time step on a joint state/value distribution.
	 *
	 *  @param distribution distribution[state][value] before the step
	 *  @return a new array holding the distribution after the step
	 */
	public double[][] updateStateValueDistribution(double[][] distribution) {
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		double[][] result = new double[stateCount][valueCount];
		// whether emissions are deterministic is a property of the automaton,
		// so decide it once instead of inside the innermost loops
		DeterministicEmitter emitter = (this instanceof DeterministicEmitter) ? (DeterministicEmitter)this : null;
		// iterate over all source states
		for (int sourceState=0; sourceState<stateCount; ++sourceState) {
			double[] sourceDist = distribution[sourceState];
			int[] targetStates = getTargets(sourceState);
			double[] probs = getTargetProbabilities(sourceState);
			// iterate over all target states that can be reached from source state
			for (int i=0; i<targetStates.length; ++i) {
				int targetState = targetStates[i];
				double transitionProb = probs[i];
				double[] targetDist = result[targetState];
				if (emitter!=null) {
					// simplified recurrence: exactly one emission per state
					int emission = emitter.getEmission(targetState);
					for (int value=0; value<valueCount; ++value) {
						double p = sourceDist[value];
						// unreachable (state, value) pairs contribute nothing
						if (p==0.0) continue;
						targetDist[performOperation(targetState, value, emission)] += p * transitionProb;
					}
				} else {
					int[] stateEmissions = getEmissions(targetState);
					double[] emissionProbs = getEmissionProbabilities(targetState);
					// iterate over all emissions that can be emitted by target state
					for (int j=0; j<stateEmissions.length; ++j) {
						int emission = stateEmissions[j];
						double emissionProb = emissionProbs[j];
						// iterate over all possible former values
						for (int value=0; value<valueCount; ++value) {
							double p = sourceDist[value];
							if (p==0.0) continue;
							targetDist[performOperation(targetState, value, emission)] +=
								p * transitionProb * emissionProb;
						}
					}
				}
			}
		}
		return result;
	}

	/** Performs one time step on a state distribution, ignoring values.
	 *
	 *  @param distribution distribution[state] before the step
	 *  @return a new array holding the state distribution after the step
	 */
	public double[] updateStateDistribution(double[] distribution) {
		int stateCount = getStateCount();
		double[] result = new double[stateCount];
		for (int sourceState=0; sourceState<stateCount; ++sourceState) {
			double mass = distribution[sourceState];
			int[] targetStates = getTargets(sourceState);
			double[] probs = getTargetProbabilities(sourceState);
			for (int i=0; i<targetStates.length; ++i) {
				result[targetStates[i]] += mass * probs[i];
			}
		}
		return result;
	}

	/** Helper for the doubling technique: concatenates two multi-step
	 *  distributions. resultDist[s0][s2][joinValues(v1,v2)] accumulates
	 *  dist0[s0][s1][v1] * dist1[s1][s2][v2] over all s1, v1, v2.
	 *  Requires getOperation() to be non-null.
	 */
	private double[][][] joinDistributions(double[][][] dist0, double[][][] dist1) {
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		SimpleOperation op = getOperation();
		double[][][] resultDist = new double[stateCount][stateCount][valueCount];
		for (int state0=0; state0<stateCount; ++state0) {
			double[][] joined = resultDist[state0];
			for (int state1=0; state1<stateCount; ++state1) {
				double[] first = dist0[state0][state1];
				for (int value1=0; value1<valueCount; ++value1) {
					double p = first[value1];
					// zero entries contribute nothing; skipping them does not change
					// the order in which terms are added to any single result cell
					if (p==0.0) continue;
					for (int state2=0; state2<stateCount; ++state2) {
						double[] second = dist1[state1][state2];
						double[] target = joined[state2];
						for (int value2=0; value2<valueCount; ++value2) {
							target[op.joinValues(value1, value2)] += p * second[value2];
						}
					}
				}
			}
		}
		return resultDist;
	}

	/** Uses the doubling algorithm to compute the state-value distribution after
	 *  a given number of steps.
	 *
	 *  @param steps number of steps, must be non-negative
	 *  @throws IllegalStateException if getOperation() is null or this automaton
	 *          is not a DeterministicEmitter
	 *  @throws IllegalArgumentException if steps is negative
	 */
	public double[][] stateValueDistributionViaDoubling(long steps) {
		if ((getOperation()==null) || !(this instanceof DeterministicEmitter)) {
			throw new IllegalStateException("Not yet implemented.");
		}
		if (steps<0) {
			throw new IllegalArgumentException("Number of steps must be non-negative, got "+steps);
		}
		// after zero steps nothing has happened (and steps has no set bits)
		if (steps==0) return stateValueStartDistribution();
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		int startValue = getStartValue();
		DeterministicEmitter emitter = (DeterministicEmitter)this;
		// exact number of binary digits of steps (steps>0); avoids floating-point log
		int bitsNeeded = 64 - Long.numberOfLeadingZeros(steps);
		// distributions dist[k][state0][state1][value] is the probability of
		// going from state0 to state1 in 2^k steps and having changed the start
		// value into value.
		double[][][][] dists = new double[bitsNeeded][][][];

		// initialize first distribution (a single step)
		dists[0] = new double[stateCount][stateCount][valueCount];
		for (int sourceState=0; sourceState<stateCount; ++sourceState) {
			int[] targetStates = getTargets(sourceState);
			double[] probs = getTargetProbabilities(sourceState);
			for (int i=0; i<targetStates.length; ++i) {
				int targetState = targetStates[i];
				int value = performOperation(targetState, startValue, emitter.getEmission(targetState));
				dists[0][sourceState][targetState][value] = probs[i];
			}
		}
		// repeated squaring: dists[bit] covers 2^bit steps; since bit<bitsNeeded,
		// 2^bit never exceeds steps
		for (int bit=1; bit<bitsNeeded; ++bit) {
			dists[bit] = joinDistributions(dists[bit-1], dists[bit-1]);
			// the previous power is only needed if its bit is set in steps
			if (((1L<<(bit-1))&steps)==0) dists[bit-1] = null;
		}
		// join the retained powers (exactly the set bits of steps) into the final one
		double[][][] targetDist = null;
		for (double[][][] d : dists) {
			if (d==null) continue;
			targetDist = (targetDist==null) ? d : joinDistributions(targetDist, d);
		}
		double[][] startDist = stateValueStartDistribution();
		double[][] result = new double[stateCount][valueCount];
		for (int startState=0; startState<stateCount; ++startState) {
			double p = startDist[startState][startValue];
			if (p==0.0) continue;
			for (int endState=0; endState<stateCount; ++endState) {
				double[] row = targetDist[startState][endState];
				for (int endValue=0; endValue<valueCount; ++endValue) {
					result[endState][endValue] += p*row[endValue];
				}
			}
		}
		return result;
	}

	/** Computes the joint state/value distribution after the given number of
	 *  steps by iterating updateStateValueDistribution(). A non-positive
	 *  number of iterations yields the start distribution. */
	public double[][] computeStateValueDistribution(long iterations) {
		double[][] distribution = stateValueStartDistribution();
		for (long t=0; t<iterations; ++t) {
			distribution = updateStateValueDistribution(distribution);
		}
		return distribution;
	}

	/** Marginalizes a joint state/value distribution over the states. */
	public double[] toValueDistribution(double[][] stateValueDistribution) {
		int valueCount = getValueCount();
		int stateCount = getStateCount();
		double[] result = new double[valueCount];
		for (int value=0; value<valueCount; ++value) {
			for (int state=0; state<stateCount; ++state) {
				result[value] += stateValueDistribution[state][value];
			}
		}
		return result;
	}

	/** Computes the value distribution after the given number of steps. */
	public double[] computeValueDistribution(int iterations) {
		return toValueDistribution(computeStateValueDistribution(iterations));
	}

	/** Returns the largest relative change |1 - after[i]/before[i]| over all
	 *  entries. Entries that are zero in both arrays yield 0/0 = NaN and are
	 *  ignored (NaN never compares greater); an entry that becomes positive
	 *  after being zero yields infinity. Never returns NaN. */
	private static double maxRelativeChange(double[] before, double[] after) {
		double maxDiff = 0.0;
		for (int i=0; i<after.length; ++i) {
			double diff = Math.abs(1.0 - after[i]/before[i]);
			if (diff>maxDiff) maxDiff = diff;
		}
		return maxDiff;
	}

	/** Updates the joint distribution of state and value until the maximal
	 *  relative change of any entry between two consecutive steps is at most
	 *  the given accuracy.
	 *
	 *  @throws NoConvergenceException if the accuracy has not been reached
	 *          after stepLimit update steps.
	 */
	public double[][] convergeToStateValueEquilibrium(double accuracy, int stepLimit) {
		double[][] distribution = stateValueStartDistribution();
		double maxDiff = Double.POSITIVE_INFINITY;
		for (int t=0; maxDiff>accuracy; ++t) {
			if (t==stepLimit) throw new NoConvergenceException();
			double[][] newDistribution = updateStateValueDistribution(distribution);
			// check if equilibrium is reached
			maxDiff = 0.0;
			for (int state=0; state<newDistribution.length; ++state) {
				double diff = maxRelativeChange(distribution[state], newDistribution[state]);
				if (diff>maxDiff) maxDiff = diff;
			}
			distribution = newDistribution;
		}
		return distribution;
	}

	/** Updates the state distribution until the maximal relative change of any
	 *  entry between two consecutive steps is at most the given accuracy.
	 *
	 *  @throws NoConvergenceException if the accuracy has not been reached
	 *          after stepLimit update steps.
	 */
	public double[] convergeToStateEquilibrium(double accuracy, int stepLimit) {
		// marginalize the start distribution over values
		double[][] start = stateValueStartDistribution();
		double[] distribution = new double[getStateCount()];
		for (int state=0; state<start.length; ++state) {
			distribution[state] = ArrayUtils.sum(start[state]);
		}
		double maxDiff = Double.POSITIVE_INFINITY;
		for (int t=0; maxDiff>accuracy; ++t) {
			if (t==stepLimit) throw new NoConvergenceException();
			double[] newDistribution = updateStateDistribution(distribution);
			// check if equilibrium is reached
			maxDiff = maxRelativeChange(distribution, newDistribution);
			distribution = newDistribution;
		}
		return distribution;
	}

}
